{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#coding:utf-8\n",
    "\n",
    "import math\n",
    "import pandas as pd\n",
    "from collections import defaultdict\n",
    "\n",
    "# Read the three CSV inputs in order: the training set, the test set and the\n",
    "# sample submission file.\n",
    "train, test, submit = (\n",
    "    pd.read_csv(csv_path)\n",
    "    for csv_path in (\n",
    "        'nameData/train.txt',\n",
    "        'nameData/test.txt',\n",
    "        'nameData/sample_submit.csv',\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>name</th>\n",
       "      <th>gender</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>闳家</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>玉璎</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>于邺</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>越英</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>蕴萱</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id name  gender\n",
       "0   1   闳家       1\n",
       "1   2   玉璎       0\n",
       "2   3   于邺       1\n",
       "3   4   越英       0\n",
       "4   5   蕴萱       0"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Partition the training rows by gender label: 0 -> female, 1 -> male.\n",
    "names_female = train.loc[train['gender'].eq(0)]\n",
    "names_male = train.loc[train['gender'].eq(1)]\n",
    "\n",
    "# totals holds the number of female ('f') and male ('m') rows in the training set.\n",
    "totals = dict(f=len(names_female), m=len(names_male))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def CharFrequency(names, total):\n",
    "    \"\"\"Return a defaultdict mapping each character in ``names`` to occurrences / ``total``.\n",
    "\n",
    "    Args:\n",
    "        names: iterable of name strings.\n",
    "        total: divisor applied to every character count (the number of names\n",
    "            of this gender at the call sites below).\n",
    "\n",
    "    Returns:\n",
    "        defaultdict(float) keyed by character. Every occurrence is counted, so a\n",
    "        character repeated inside one name contributes more than once.\n",
    "    \"\"\"\n",
    "    counts = defaultdict(int)\n",
    "    for name in names:\n",
    "        for char in name:\n",
    "            counts[char] += 1\n",
    "    # Divide once per character instead of adding 1/total per occurrence: the\n",
    "    # repeated additions accumulate floating-point rounding error, which later\n",
    "    # leaks into counts recovered as frequency * total.\n",
    "    return defaultdict(float, {char: float(n) / total for char, n in counts.items()})\n",
    "\n",
    "\n",
    "# Frequency of each character among all female (male) names. This step estimates\n",
    "# P(Xi | female) and P(Xi | male).\n",
    "frequency_list_f = CharFrequency(names_female['name'], totals['f'])\n",
    "frequency_list_m = CharFrequency(names_male['name'], totals['m'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.004144009000562539\n"
     ]
    }
   ],
   "source": [
    "# Spot-check one character in each frequency table.\n",
    "# .get() is used instead of [] because indexing a defaultdict with a missing key\n",
    "# inserts that key, which would silently grow len(frequency_list_*) -- a value the\n",
    "# smoothing step reads as its number of distinct characters.\n",
    "print(frequency_list_f.get('娟', 0))\n",
    "print(frequency_list_m.get('钢', 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def LaplaceSmooth(char, frequency_list, total, alpha=1.0):\n",
    "    \"\"\"Return the Laplace-smoothed frequency of ``char``.\n",
    "\n",
    "    Characters in the prediction set may never occur in the training set, so the\n",
    "    raw frequencies are smoothed to keep every probability non-zero.\n",
    "\n",
    "    Args:\n",
    "        char: the character to look up.\n",
    "        frequency_list: mapping of character -> frequency (count / total).\n",
    "        total: the divisor that was used to build ``frequency_list``.\n",
    "        alpha: additive smoothing constant.\n",
    "\n",
    "    Returns:\n",
    "        (count + alpha) / (total + distinct_chars * alpha), where count is\n",
    "        recovered as frequency * total and distinct_chars is len(frequency_list).\n",
    "    \"\"\"\n",
    "    # Look up with .get(): indexing a defaultdict with an unseen character would\n",
    "    # insert it, growing len(frequency_list) and shifting the denominator for\n",
    "    # every later call, so results would depend on which names were scored first.\n",
    "    count = frequency_list.get(char, 0) * total\n",
    "    distinct_chars = len(frequency_list)\n",
    "    return (count + alpha) / (total + distinct_chars * alpha)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name-independent part of each gender's log-score: the log prior plus the sum of\n",
    "# log(1 - frequency) over every character recorded for that gender.\n",
    "# NOTE(review): raw frequencies are used here, whereas GetLogProb works with the\n",
    "# Laplace-smoothed ones -- confirm the mismatch is intended.\n",
    "male_share = train['gender'].mean()  # gender is 0/1, so the mean is the share of 1s\n",
    "\n",
    "base_f = math.log(1 - male_share) + sum(\n",
    "    [math.log(1 - freq) for freq in frequency_list_f.values()]\n",
    ")\n",
    "base_m = math.log(male_share) + sum(\n",
    "    [math.log(1 - freq) for freq in frequency_list_m.values()]\n",
    ")\n",
    "\n",
    "bases = dict(f=base_f, m=base_m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def GetLogProb(char, frequency_list, total):\n",
    "    \"\"\"Return log(p) - log(1 - p), where p is the smoothed frequency of ``char``.\n",
    "\n",
    "    ``p`` comes from LaplaceSmooth(char, frequency_list, total) with its default alpha.\n",
    "    \"\"\"\n",
    "    p = LaplaceSmooth(char, frequency_list, total)\n",
    "    log_present = math.log(p)\n",
    "    log_absent = math.log(1 - p)\n",
    "    return log_present - log_absent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ComputeLogProb(name, bases, totals, frequency_list_m, frequency_list_f):\n",
    "    \"\"\"Score ``name`` under both genders.\n",
    "\n",
    "    Starts from bases['m'] / bases['f'] and adds GetLogProb for every character of\n",
    "    the name, using the matching frequency table and total.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys 'male' and 'female' holding the two accumulated scores.\n",
    "    \"\"\"\n",
    "    scores = {'male': bases['m'], 'female': bases['f']}\n",
    "    for char in name:\n",
    "        scores['male'] += GetLogProb(char, frequency_list_m, totals['m'])\n",
    "        scores['female'] += GetLogProb(char, frequency_list_f, totals['f'])\n",
    "    return scores\n",
    "\n",
    "\n",
    "def GetGender(LogProbs):\n",
    "    \"\"\"Return True when the 'male' score is strictly greater than the 'female' score.\"\"\"\n",
    "    return LogProbs['female'] < LogProbs['male']\n",
    "\n",
    "\n",
    "# Classify every test name; the boolean verdict is stored as 1 (True) or 0 (False).\n",
    "result = [\n",
    "    int(GetGender(ComputeLogProb(name, bases, totals, frequency_list_m, frequency_list_f)))\n",
    "    for name in test['name']\n",
    "]\n",
    "\n",
    "# Put the predictions into the submission frame and write it out.\n",
    "submit['gender'] = result\n",
    "submit.to_csv('my_NB_prediction.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n"
     ]
    }
   ],
   "source": [
    "# Ad-hoc check of a single name: prints 1 when the male score is higher, else 0.\n",
    "sample_scores = ComputeLogProb('阿佳斯', bases, totals, frequency_list_m, frequency_list_f)\n",
    "print(int(GetGender(sample_scores)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def testSex(name):\n",
    "    \"\"\"Return 1 if ``name`` scores higher as male than as female, otherwise 0.\n",
    "\n",
    "    Reads the notebook-level ``bases``, ``totals``, ``frequency_list_m`` and\n",
    "    ``frequency_list_f``.\n",
    "    \"\"\"\n",
    "    scores = ComputeLogProb(name, bases, totals, frequency_list_m, frequency_list_f)\n",
    "    return int(GetGender(scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "男性： 34 \n",
      " 女性： 25\n"
     ]
    }
   ],
   "source": [
    "# Tally the predicted gender of every name in the text file (one name per line).\n",
    "# NOTE(review): open() falls back to the locale's default encoding here; pass an\n",
    "# explicit encoding= once the file's actual encoding is confirmed.\n",
    "manCount = womanCount = 0\n",
    "with open('information/人物名称.txt','r') as f:\n",
    "    for line in f:\n",
    "        # Strip the newline and any surrounding whitespace so they are not scored\n",
    "        # as characters of the name.\n",
    "        name = line.strip()\n",
    "        # Skip blank lines (e.g. a trailing empty line): an empty name would still\n",
    "        # be classified from the base scores alone and inflate one of the counts.\n",
    "        if not name:\n",
    "            continue\n",
    "        if testSex(name) == 1:\n",
    "            manCount += 1\n",
    "        else:\n",
    "            womanCount += 1\n",
    "print('男性：',manCount,'\\n','女性：',womanCount)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>name</th>\n",
       "      <th>pred</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>辰君</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>佳遥</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>淼剑</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>浩苳</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>俪妍</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id name  pred\n",
       "0   0   辰君     0\n",
       "1   1   佳遥     0\n",
       "2   2   淼剑     1\n",
       "3   3   浩苳     1\n",
       "4   4   俪妍     0"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test['pred'] = result\n",
    "test.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
